import os
import json
import math
import torch
import numpy
import argparse
from scipy.io import arff
# import weka.core.jvm
# import weka.core.converters
import re
import copy
from collections import Counter
from collections import defaultdict
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn import metrics
from scipy.spatial.distance import cdist
from numpy import dot
from numpy.linalg import norm
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

import scikit_wrappers_0 as scikit_wrappers
import pickle
from sklearn.model_selection import train_test_split
from sklearn.cluster import kmeans_plusplus
from filterpy.kalman import KalmanFilter
import time
from Common_functions import *

min_round_before_agg = 190
num_rounds = 200
batch_size = 64
num_clusters = 3
# One independent list per cluster. ``[[]] * num_clusters`` would create
# num_clusters references to the SAME list, so ``+=`` on any entry would feed
# every cluster's measurements into every cluster.
Cluster_measurements = [[] for _ in range(num_clusters)]
total_iters = [2, 1, 2]  # EM iterations per cluster, passed to perform_filtering
save_path = 'Save_models/'

# For each round in [min_round_before_agg, num_rounds), load the saved
# per-cluster weights and intercepts, concatenate them column-wise into a
# [weights | intercept] matrix and store it as one (3210, 1) column vector
# (the measurement for that round).
for t in numpy.arange(min_round_before_agg, num_rounds):
    suffix = str(t) + '_bs_' + str(batch_size) + '.pkl'
    with open(save_path + 'No_mem_evo_SVM_Cluster_Agg_weights' + suffix, 'rb') as fp:
        weights_loaded = pickle.load(fp)

    with open(save_path + 'No_mem_evo_SVM_Cluster_Agg_intercept' + suffix, 'rb') as fp:
        intercept_loaded = pickle.load(fp)

    for i in range(num_clusters):
        y_i = numpy.concatenate(
            (weights_loaded[i], numpy.expand_dims(intercept_loaded[i], 1)),
            axis=-1)
        # Must hold exactly 3210 entries; it is reshaped back to (10, 321)
        # after smoothing.
        y_i = numpy.reshape(y_i, (3210, 1))

        Cluster_measurements[i].append(y_i)


def calc_acc_kf(idx_list, w):
    """Score two saved per-client models on each client's test split.

    For every client index ``i`` in ``idx_list`` and every round in ``[59]``:

    * loads ``'NoMemEvo_Agg_Batchwise_EXPRTD_round_<t>; Client<i> model_cluster_wise_save.pt'``,
      replaces its linear classifier with ``w`` (all columns except the last
      as ``coef_``, the last column as ``intercept_``) and scores it;
    * loads ``'Mem_Agg_Batchwise_EXPRTD_round_<t>; Client<i> model_cluster_wise_save_v2_app_1.pt'``
      and scores it unchanged.

    Reads the module-level ``test_`` and ``test_labels_`` (loaded further down
    in this script), ``save_path`` and ``classifier_score_modded``.

    Args:
        idx_list: iterable of client indices into ``test_`` / ``test_labels_``.
        w: 2-D array laid out as ``[coef | intercept]``. Presumably of shape
            (10, 321), as produced by reshaping a smoothed state; TODO confirm
            for other callers.

    Returns:
        Tuple ``(mean score with w, mean score of the second checkpoint)``,
        averaged over all clients and rounds.
    """
    rounds = [59]
    client_accs = []
    client_accs_2 = []
    for i in idx_list:
        print('User: ', i)
        test = test_[i]
        test_labels = test_labels_[i]

        round_accs = []
        round_accs_2 = []

        for t in rounds:
            print(t)
            model_load_evo = torch.load(save_path + 'NoMemEvo_Agg_Batchwise_EXPRTD_round_' + str(t) + '; Client' + str(
                i) + ' model_cluster_wise_save.pt')
            model_load_evo.encoder.eval()

            # Same [coef | intercept] split as the evaluation loop at the end
            # of the script.
            model_load_evo.classifier.coef_ = copy.copy(w[:, :-1])
            model_load_evo.classifier.intercept_ = copy.copy(w[:, -1])
            round_accs.append(classifier_score_modded(model_load_evo, test, test_labels))

            model_load_mem_evo = torch.load(save_path + 'Mem_Agg_Batchwise_EXPRTD_round_' + str(t) + '; Client' + str(
                i) + ' model_cluster_wise_save_v2_app_1.pt')
            model_load_mem_evo.encoder.eval()
            round_accs_2.append(classifier_score_modded(model_load_mem_evo, test, test_labels))

        client_accs.append(round_accs)
        client_accs_2.append(round_accs_2)

    return numpy.mean(client_accs), numpy.mean(client_accs_2)


def _symmetrize(mat):
    """Return the symmetric part ``0.5 * (mat + mat.T)`` of a square matrix."""
    return 0.5 * (mat + mat.T)


def perform_filtering(Cluster_1_measurements, set_F_to_eye=True, Sigma_0_Scale=1, total_iters=15, set_Sigma_to_P=True,
                      init_R=None, init_Q=None, test_ixd=None):
    """Fit a linear-Gaussian state-space model to a measurement sequence by EM.

    Each EM iteration does the following:

    1. Runs a filterpy ``KalmanFilter`` forward over the measurements.
    2. Runs a Rauch-Tung-Striebel backward smoother.
    3. Runs the lag-one smoothed cross-covariance recursion.
    4. Re-estimates ``F`` (unless ``set_F_to_eye``), ``Q`` and ``R`` from the
       smoothed moments.

    ``H`` is fixed to the identity, so the state and the measurements share
    one dimension, which is inferred from the first measurement.

    Args:
        Cluster_1_measurements: indexable sequence of column vectors of shape
            ``(dim, 1)``. Only the first ``num_rounds - min_round_before_agg``
            entries (module-level constants) are filtered.
        set_F_to_eye: if True keep ``F = I``; otherwise re-estimate it as
            ``B A^-1`` after every iteration.
        Sigma_0_Scale: scale of the isotropic initial covariance used when
            ``set_Sigma_to_P`` is False.
        total_iters: number of EM iterations; must be at least 1.
        set_Sigma_to_P: if False, every pass starts from
            ``Sigma_0_Scale * I``. If True, the first pass starts from ``I``
            and later passes from the smoothed initial covariance.
        init_R: initial measurement-noise covariance; ``None`` means
            ``0.01 * I``.
        init_Q: initial process-noise covariance; ``None`` means
            ``0.0001 * I``.
        test_ixd: unused; kept for backward compatibility.

    Returns:
        ``(x_smooths, R, Q)``:

        * ``x_smooths`` is the list of smoothed state means from the last
          iteration. It has ``n_steps + 1`` entries; index 0 is the smoothed
          initial state.
        * ``R`` and ``Q`` are the matrices re-estimated by the last M-step.

    Raises:
        ValueError: if ``total_iters`` is less than 1.
    """
    if total_iters < 1:
        raise ValueError("total_iters must be at least 1, got %r" % (total_iters,))

    # One filter step per round in [min_round_before_agg, num_rounds).
    n_steps = num_rounds - min_round_before_agg
    dim = numpy.shape(Cluster_1_measurements[0])[0]
    # Built per call (not as default arguments) so no large matrices are
    # allocated at import time or shared between calls.
    if init_R is None:
        init_R = 0.01 * numpy.eye(dim)
    if init_Q is None:
        init_Q = 0.0001 * numpy.eye(dim)

    filter_1 = KalmanFilter(dim_x=dim, dim_z=dim)
    filter_1.F = numpy.eye(dim)
    filter_1.H = numpy.eye(dim)
    filter_1.Q = init_Q
    filter_1.R = init_R
    if not set_Sigma_to_P:
        filter_1.P = Sigma_0_Scale * numpy.eye(dim)
    else:
        filter_1.P = numpy.eye(dim)

    for k in range(total_iters):
        print("Iter: ", k)

        # ---- E-step: forward Kalman pass ----
        # x_posts[0] is the initial state; x_posts[t] is the filtered mean
        # after t updates. x_priors[t] / P_priors[t] is the prediction that
        # the (t + 1)-th update corrects.
        x_priors = []
        P_priors = []
        x_posts = [filter_1.x]
        P_posts = [_symmetrize(filter_1.P)]

        # Copies of this pass's F and R, used by the smoother and K_n below.
        F = copy.copy(filter_1.F)
        R = copy.copy(filter_1.R)

        for step in range(n_steps):
            filter_1.predict()
            x_priors.append(filter_1.x_prior)
            P_priors.append(_symmetrize(filter_1.P_prior))
            filter_1.update(Cluster_1_measurements[step])
            x_posts.append(filter_1.x)
            P_posts.append(_symmetrize(filter_1.P))

        # ---- E-step: RTS backward pass ----
        # Smoother gains J_t = P_{t|t} F^T P_{t+1|t}^{-1}. They are computed
        # once and reused by the lag-one recursion, which previously
        # re-inverted the same matrices twice per step.
        gains = [numpy.matmul(numpy.matmul(P_posts[t], F.T), numpy.linalg.inv(P_priors[t]))
                 for t in range(n_steps)]

        x_s = x_posts[-1]
        P_s = P_posts[-1]
        x_smooths = [x_s]
        P_smooths = [_symmetrize(P_s)]
        for t in range(n_steps - 1, -1, -1):
            x_s = x_posts[t] + numpy.matmul(gains[t], x_s - x_priors[t])
            P_s = P_posts[t] + numpy.matmul(numpy.matmul(gains[t], P_s - P_priors[t]), gains[t].T)
            x_smooths.append(x_s)
            P_smooths.append(_symmetrize(P_s))
        # Built back-to-front; restore chronological order.
        x_smooths.reverse()
        P_smooths.reverse()

        # ---- E-step: lag-one smoothed cross-covariances ----
        # After the final reverse, P_lag[t - 1] pairs smoothed states t and
        # t - 1 (t = 1..n_steps). The recursion is seeded at the last step
        # from the final Kalman gain (H = I, so S = P_prior + R).
        K_n = numpy.matmul(P_priors[-1], numpy.linalg.inv(P_priors[-1] + R))
        P_lag = [numpy.matmul(numpy.eye(dim) - K_n, numpy.matmul(F, P_posts[-2]))]
        for t in range(n_steps - 1, 0, -1):
            P_lag.append(
                numpy.matmul(P_posts[t], gains[t - 1].T)
                + numpy.matmul(numpy.matmul(gains[t], P_lag[-1] - numpy.matmul(F, P_posts[t])), gains[t - 1].T))
        P_lag.reverse()

        # ---- M-step: accumulate smoothed sufficient statistics ----
        A = 0
        B = 0
        C = 0
        R_sum = 0
        for t in range(1, n_steps + 1):
            x_prev = x_smooths[t - 1]
            x_cur = x_smooths[t]
            A += P_smooths[t - 1] + numpy.matmul(x_prev, x_prev.T)
            B += P_lag[t - 1] + numpy.matmul(x_cur, x_prev.T)
            C += P_smooths[t] + numpy.matmul(x_cur, x_cur.T)
            resid = Cluster_1_measurements[t - 1] - x_cur
            R_sum += numpy.matmul(resid, resid.T) + P_smooths[t]

        # Rebuild the filter with the re-estimated parameters for the next
        # pass. Normalize by the number of steps actually filtered, not by
        # len(Cluster_1_measurements).
        filter_1 = KalmanFilter(dim_x=dim, dim_z=dim)
        filter_1.H = numpy.eye(dim)
        if not set_F_to_eye:
            filter_1.F = numpy.matmul(B, numpy.linalg.inv(A))
        else:
            filter_1.F = numpy.eye(dim)
        Q_temp = 1 / n_steps * (C - numpy.matmul(filter_1.F, B.T))
        R_temp = 1 / n_steps * R_sum
        filter_1.Q = _symmetrize(Q_temp)
        filter_1.R = _symmetrize(R_temp)
        filter_1.x = x_smooths[0]
        if not set_Sigma_to_P:
            filter_1.P = Sigma_0_Scale * numpy.eye(dim)
        else:
            filter_1.P = P_smooths[0]

    return x_smooths, filter_1.R, filter_1.Q


# Per-cluster average element variance of the measurement vectors; it scales
# the isotropic initial covariance handed to the EM filter below.
scale = [
    numpy.mean([numpy.var(measurement) for measurement in Cluster_measurements[c]])
    for c in range(num_clusters)
]

# Run EM smoothing independently for each cluster and keep the smoothed
# trajectory plus the final noise-covariance estimates.
x_smooths = []
R_s = []
Q_s = []
for c in range(num_clusters):
    x, r, q = perform_filtering(
        Cluster_measurements[c],
        Sigma_0_Scale=0.0001 * scale[c],
        set_Sigma_to_P=False,
        set_F_to_eye=False,
        test_ixd=[0, 1, 2],
        total_iters=total_iters[c],
    )
    x_smooths.append(x)
    R_s.append(r)
    Q_s.append(q)

# Last smoothed state of each cluster, reshaped back to a (10, 321)
# [coef | intercept] matrix.
w_smooths = [numpy.reshape(trajectory[-1], (10, 321)) for trajectory in x_smooths]

with open(save_path + 'No_mem_snapshot_SVM_Cluster_list' + str(num_rounds - 1) + '_bs_' + str(
        batch_size) + '.pkl', 'rb') as fp:
    client_lists_loaded = pickle.load(fp)


with open('test_x', 'rb') as fp:
    test_ = pickle.load(fp)
with open('test_y', 'rb') as fp:
    test_labels_ = pickle.load(fp)

# One encoder (client 0's saved model) is reused for every client. Only its
# classifier coefficients are swapped for the smoothed weights of the
# client's cluster.
local_model = torch.load(save_path + 'SVM_output_layer_Model_for_client_' + str(0) + '.pt')
local_model.encoder.eval()

acc = 0
count = 0
for c in range(num_clusters):
    for client_id in client_lists_loaded[c]:
        test = test_[client_id]
        test_labels = test_labels_[client_id]

        feats = local_model.encode(test)
        # w_smooths[c] is laid out as [coef | intercept].
        local_model.classifier.coef_ = w_smooths[c][:, 0:-1]
        local_model.classifier.intercept_ = w_smooths[c][:, -1]
        count += 1
        acc += classifier_score_modded_feats(local_model, feats, test_labels)

if count == 0:
    raise RuntimeError("No clients found in the loaded cluster lists; cannot compute mean accuracy.")
acc = acc / count  # mean accuracy over all evaluated clients
print("Total acc: ", acc)
with open(save_path + 'KF_performance_bs_' + str(batch_size) +'.pkl', 'wb') as fp:
    pickle.dump(acc, fp)